"""
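Miscellaneous utilities: stdout logging, YAML config I/O, torch metrics, LargeST regression losses, and numeric/formatting helpers.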

"""
import os
import sys
from time import time
import datetime
import io
import socket

import pickle
import yaml
import torch
import torchmetrics.functional as MF
import torchmetrics
import numpy as np
import scipy.stats as st


def get_host_name():
    """Return the host name of the machine running this process."""
    host = socket.gethostname()
    return host

def get_current_time():
    """Return the current local time as a ``YYYYmmdd-HHMMSS`` string.

    The timestamp is truncated to whole seconds and shifted by a fixed hour
    offset (currently 0) before formatting.
    """
    hour_offset = 0
    seconds = int(time()) + hour_offset * 3600
    return datetime.datetime.fromtimestamp(seconds).strftime('%Y%m%d-%H%M%S')

class Logger(object):
    """
    Print hook that mirrors everything written to it onto both the terminal
    (the ``sys.stdout`` captured at construction time) and a log file.

    Typical use is ``sys.stdout = Logger(path)``. Every ``write`` is flushed
    immediately so the log file stays current. Call ``close()`` when done to
    release the log file handle.
    """
    def __init__(self, log_path):
        self.terminal = sys.stdout
        # Append mode: repeated runs accumulate in the same log file.
        self.log = open(log_path, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left open."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Forward stream
        # attributes (e.g. ``encoding``, ``isatty``, ``fileno``) to the
        # terminal so code that inspects ``sys.stdout`` keeps working.
        # Guard ``terminal`` itself to avoid infinite recursion when the
        # instance was created without ``__init__`` (e.g. unpickling).
        if name == "terminal":
            raise AttributeError(name)
        return getattr(self.terminal, name)

def load_sdmp_conf_with_default(conf_path):
    """Load ``conf_path`` layered over ``config/sdmp_default.yml``.

    The defaults file is resolved relative to this module's directory, so the
    result does not depend on the current working directory. Keys from
    ``conf_path`` override the defaults (see ``load_train_conf``).
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    defaults_file = os.path.join(module_dir, 'config/sdmp_default.yml')
    return load_train_conf(conf_path, defaults_file)

def load_mlp_conf_with_default(conf_path):
    """Load ``conf_path`` layered over ``config/mlp_default.yml``.

    The defaults file is resolved relative to this module's directory, so the
    result does not depend on the current working directory. Keys from
    ``conf_path`` override the defaults (see ``load_train_conf``).
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    defaults_file = os.path.join(module_dir, 'config/mlp_default.yml')
    return load_train_conf(conf_path, defaults_file)

def load_train_conf(conf_path, default_path=None):
    """Load a YAML training configuration, optionally layered over defaults.

    Args:
        conf_path: path to the run-specific YAML file.
        default_path: optional path to a YAML file of default values. When
            given, top-level keys from ``conf_path`` override the defaults
            (shallow merge via ``dict.update``: nested mappings are replaced,
            not merged).

    Returns:
        The configuration mapping. An empty YAML file (which PyYAML loads as
        ``None``) is treated as an empty mapping so it can still be merged.
    """
    # NOTE: yaml.full_load resolves more tags than yaml.safe_load; only load
    # configuration files from trusted sources.
    def _read(path):
        # YAML is UTF-8 by spec; don't depend on the platform's locale.
        with open(path, 'r', encoding='utf-8') as fin:
            loaded = yaml.full_load(fin)
        return {} if loaded is None else loaded

    train_conf = _read(conf_path)
    if default_path is not None:
        merged = _read(default_path)
        merged.update(train_conf)
        train_conf = merged
    return train_conf

def export_train_conf(conf_path, conf):
    """Write ``conf`` to ``conf_path`` as block-style YAML, preserving key order."""
    dump_options = dict(default_flow_style=False, sort_keys=False)
    with open(conf_path, 'w') as out_file:
        yaml.dump(conf, out_file, **dump_options)

# Torch metrics
def torch_acc(y_hats, ys):
    """Return ``torchmetrics.functional.accuracy(y_hats, ys)``.

    NOTE(review): recent torchmetrics releases require a ``task`` argument for
    ``accuracy``; confirm this call works with the installed version.
    """
    score = MF.accuracy(y_hats, ys)
    return score

def torch_f1(y_hats, ys, task='multilabel', average='micro'):
    """F1 score of ``y_hats`` against ``ys`` using ``torchmetrics.F1Score``.

    Args:
        y_hats: predictions. For multilabel/multiclass tasks
            ``y_hats.shape[1]`` is taken as the number of labels/classes.
        ys: targets.
        task: torchmetrics task: ``'multilabel'`` (default), ``'multiclass'``
            or ``'binary'``.
        average: torchmetrics averaging mode, ``'micro'`` by default.

    Returns:
        The F1 score tensor, computed on ``y_hats``' device.
    """
    metric_kwargs = {'task': task, 'average': average}
    # torchmetrics sizes multilabel metrics with ``num_labels`` and
    # multiclass metrics with ``num_classes``; binary needs neither.
    if task == 'multilabel':
        metric_kwargs['num_labels'] = int(y_hats.shape[1])
    elif task == 'multiclass':
        metric_kwargs['num_classes'] = int(y_hats.shape[1])
    f1_score = torchmetrics.F1Score(**metric_kwargs).to(y_hats.device)
    return f1_score(y_hats, ys)

def evaluate(model, dataloader, metric_fn=torch_f1):
    """Score ``model`` on every ``(x, y)`` batch of ``dataloader``.

    The model is switched to eval mode (and left there). Forward passes run
    under ``torch.no_grad()``. Predictions and targets are concatenated
    along dim 0 and passed to ``metric_fn(predictions, targets)``, whose
    result is returned.
    """
    model.eval()
    targets, predictions = [], []
    for x, y in dataloader:
        targets.append(y)
        with torch.no_grad():
            predictions.append(model(x))
    return metric_fn(torch.cat(predictions), torch.cat(targets))

def evaluate_with_time(model, dataloader, metric_fn=torch_f1):
    """Score ``model`` on ``dataloader`` and record per-batch timings.

    The model is switched to eval mode (and left there). Forward passes run
    under ``torch.no_grad()``.

    Returns:
        ``(metric, sample_time, infer_time)`` where:

        - ``metric`` is ``metric_fn(all_predictions, all_targets)``.
        - ``sample_time[i]`` is the wall-clock seconds spent fetching batch
          ``i`` from the dataloader.
        - ``infer_time[i]`` is the seconds spent on its forward pass.
    """
    model.eval()
    ys = []
    y_hats = []
    sample_time, infer_time = [], []
    tic = time()
    for x, y in dataloader:
        sample_time.append(time() - tic)
        tic = time()
        with torch.no_grad():
            ys.append(y)
            y_hats.append(model(x))
        # NOTE(review): on CUDA the forward pass is asynchronous, so this may
        # under-report inference time unless the device is synchronized.
        infer_time.append(time() - tic)
        tic = time()
    return metric_fn(torch.cat(y_hats), torch.cat(ys)), sample_time, infer_time
# end of Torch metrics

############### torch utils from LargeST
# https://github.com/liuxu77/LargeST
def masked_mse(preds, labels, null_val):
    """Mean squared error that ignores entries whose label equals ``null_val``.

    Args:
        preds: prediction tensor, compared elementwise with ``labels``.
        labels: target tensor.
        null_val: sentinel marking missing labels. May be a tensor or a plain
            Python number; NaN means "ignore NaN labels".

    Returns:
        0-d tensor: the squared error averaged over the non-masked entries.
        The mask is rescaled by its mean, so masked entries do not dilute the
        average. NaN losses are treated as zero, and an all-masked input
        gives 0.
    """
    # torch.isnan only accepts tensors; wrap plain numbers so callers may
    # pass e.g. 0.0 or float('nan') directly. No-op for tensors.
    null_val = torch.as_tensor(null_val)
    if torch.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = (labels != null_val)
    mask = mask.float()
    mask /= torch.mean((mask))
    # An all-masked input divides by zero above; zero out the resulting NaNs.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = (preds - labels)**2
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def masked_rmse(preds, labels, null_val):
    """Masked root-mean-squared error: the square root of ``masked_mse``."""
    mse = masked_mse(preds=preds, labels=labels, null_val=null_val)
    return mse.sqrt()


def masked_mae(preds, labels, null_val):
    """Mean absolute error that ignores entries whose label equals ``null_val``.

    Args:
        preds: prediction tensor, compared elementwise with ``labels``.
        labels: target tensor.
        null_val: sentinel marking missing labels. May be a tensor or a plain
            Python number; NaN means "ignore NaN labels".

    Returns:
        0-d tensor: the absolute error averaged over the non-masked entries.
        The mask is rescaled by its mean. NaN losses are treated as zero, and
        an all-masked input gives 0.
    """
    # torch.isnan only accepts tensors; wrap plain numbers. No-op for tensors.
    null_val = torch.as_tensor(null_val)
    if torch.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = (labels != null_val)
    mask = mask.float()
    mask /= torch.mean((mask))
    # An all-masked input divides by zero above; zero out the resulting NaNs.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.abs(preds - labels)
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def masked_mape(preds, labels, null_val):
    """Mean absolute percentage error (as a fraction), ignoring null labels.

    Args:
        preds: prediction tensor, compared elementwise with ``labels``.
        labels: target tensor.
        null_val: sentinel marking missing labels. May be a tensor or a plain
            Python number; NaN means "ignore NaN labels".

    Returns:
        0-d tensor: ``|preds - labels| / labels`` averaged over the non-masked
        entries. NaN losses are treated as zero.

    NOTE: a non-masked label of 0 yields ``inf`` here; pass ``null_val=0`` if
    zero labels should be excluded.
    """
    # torch.isnan only accepts tensors; wrap plain numbers. No-op for tensors.
    null_val = torch.as_tensor(null_val)
    if torch.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = (labels != null_val)
    mask = mask.float()
    mask /= torch.mean((mask))
    # An all-masked input divides by zero above; zero out the resulting NaNs.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.abs(preds - labels) / labels
    # Masked zero-label entries become inf * 0 = NaN and are zeroed below.
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def compute_all_metrics(preds, labels, null_val):
    """Return ``(mae, mape, rmse)`` as Python floats, each masked by ``null_val``."""
    metric_fns = (masked_mae, masked_mape, masked_rmse)
    mae, mape, rmse = (fn(preds, labels, null_val).item() for fn in metric_fns)
    return mae, mape, rmse

class StandardScaler():
    """Z-score scaler.

    ``transform`` computes ``(data - mean) / std`` and ``inverse_transform``
    computes ``data * std + mean``.
    """
    def __init__(self, mean, std):
        # torch.tensor copies its input, so later changes to the caller's
        # arrays do not affect the scaler.
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def _stats_like(self, data):
        """Return ``(mean, std)`` on ``data``'s device when it is a tensor."""
        # Non-scalar CPU stats cannot be combined with GPU tensors, so align
        # devices; ``.to`` is a no-op when the device already matches.
        if isinstance(data, torch.Tensor):
            return self.mean.to(data.device), self.std.to(data.device)
        return self.mean, self.std

    def transform(self, data):
        mean, std = self._stats_like(data)
        return (data - mean) / std

    def inverse_transform(self, data):
        mean, std = self._stats_like(data)
        return (data * std) + mean

class LargeSTLossWrapper:
    """
    Compute regression losses and metrics for LargeST-style forecasting.

    Losses are reported both in the normalised space the model works in and,
    after undoing the ``StandardScaler`` normalisation, in the original units.
    """
    def __init__(self, transform_stats):
        # transform_stats[0] / transform_stats[1] are passed to the scaler as
        # its mean / std.
        self.transform_stats = transform_stats
        self.scaler = StandardScaler(transform_stats[0], transform_stats[1])
        # Every mask value chosen in __call__ is appended here; the list is
        # never cleared, so it grows by one entry per call.
        self.mask_value_trace = []

        self.mse_loss_fn = torch.nn.MSELoss()
        self.mae_loss_fn = torch.nn.L1Loss()

    def _inverse_transform(self, tensors):
        """Undo the normalisation on one tensor or on each tensor of a list."""
        if not isinstance(tensors, list):
            return self.scaler.inverse_transform(tensors)
        return [self.scaler.inverse_transform(t) for t in tensors]

    def __call__(self, preds, labels):
        """
        Compute losses and metrics for one batch.

        ``preds`` is reshaped to ``labels.shape`` when the shapes differ.

        Returns:
            stats: ``[mae, mape, rmse]`` as Python floats, computed after the
                inverse transform with masking.
            origin_loss: ``[mse, mae]`` masked tensors after the inverse
                transform.
            rescale_loss: ``[mse, mae]`` unmasked tensors in the normalised
                space.
        """
        if preds.shape != labels.shape:
            preds = preds.reshape(labels.shape)

        # Losses in the normalised (model) space.
        rescale_loss = [self.mse_loss_fn(preds, labels),
                        self.mae_loss_fn(preds, labels)]

        # Back to original units. Mask the smallest label when it is below 1;
        # otherwise mask zeros.
        preds, labels = self._inverse_transform([preds, labels])
        lowest = labels.min()
        mask_value = lowest if lowest < 1 else torch.tensor(0)
        self.mask_value_trace.append(mask_value)

        mse = masked_mse(preds, labels, mask_value)
        mae = masked_mae(preds, labels, mask_value)
        mape = masked_mape(preds, labels, mask_value)
        rmse = masked_rmse(preds, labels, mask_value)

        stats = [mae.item(), mape.item(), rmse.item()]
        return stats, [mse, mae], rescale_loss

    def all_horizon(self, preds, labels):
        """
        Compute metrics separately for each forecasting horizon.

        ``preds`` and ``labels`` are sequences with one entry per horizon.
        Returns three lists ``(mae, mape, rmse)`` with one value per horizon.
        """
        per_horizon = [self.__call__(p, l)[0] for p, l in zip(preds, labels)]
        all_mae = [s[0] for s in per_horizon]
        all_mape = [s[1] for s in per_horizon]
        all_rmse = [s[2] for s in per_horizon]
        return all_mae, all_mape, all_rmse

############### end of torch utils from LargeST

def fast_numpy_slicing(A, rows, cols):
    """
    Fast equivalent of ``A[rows, :][:, cols]`` for a 2-D array ``A``.

    Args:
        A: 2-D array.
        rows, cols: 1-D integer index arrays or sequences. Negative indices
            count from the end, as in NumPy indexing.

    Returns:
        Array of shape ``(len(rows), len(cols))`` with
        ``res[i, j] == A[rows[i], cols[j]]``.
    """
    n_r, n_c = A.shape
    # intp avoids overflow of rows * n_c when callers pass int32 indices.
    rows = np.asarray(rows, dtype=np.intp)
    cols = np.asarray(cols, dtype=np.intp)
    # Normalise negative indices; otherwise a negative column would wrap into
    # the previous row of the flattened array.
    rows = np.where(rows < 0, rows + n_r, rows)
    cols = np.where(cols < 0, cols + n_c, cols)
    # Row-major flat positions of every selected (row, col) pair.
    idx = rows.reshape(-1, 1) * n_c + cols.reshape(1, -1)
    # With a 2-D index array, ``take`` on the flattened A already returns the
    # (len(rows), len(cols)) result.
    return A.take(idx)

def torch_batch_matrix_mul_matrix_list(a, list_b, batch_size=1024, device="cpu"):
    """
    Compute ``a @ list_b[0] @ ... @ list_b[-1]`` with torch, in row batches
    of ``a``, and return the result as a float32 NumPy array.

    input:
    a: the first matrix (NumPy array); its rows are processed in batches
    list_b: either a list/tuple of matrices or a single matrix (NumPy arrays)
    batch_size: rows of ``a`` per batch; -1 processes all rows at once
    device: torch device on which the products are computed
    """
    if not isinstance(list_b, (list, tuple)):
        list_b = [list_b]
    n_rows = a.shape[0]
    if batch_size == -1:
        # max(..., 1) keeps range() valid when ``a`` has no rows.
        batch_size = max(n_rows, 1)
    with torch.no_grad():
        list_torch_b = [torch.from_numpy(b).float().to(device) for b in list_b]
        list_res = []
        for start in range(0, n_rows, batch_size):
            cur_product = torch.from_numpy(a[start:start + batch_size, :]).float().to(device)
            for torch_b in list_torch_b:
                # Not ``@=``: the product generally has a different shape, so
                # it cannot be computed in place.
                cur_product = cur_product @ torch_b
            list_res.append(cur_product.cpu().numpy())
    if not list_res:
        # ``a`` has no rows: return an empty result with the right width.
        n_cols = list_b[-1].shape[1] if list_b else a.shape[1]
        return np.empty((0, n_cols), dtype=np.float32)
    return np.concatenate(list_res, axis=0)

def torch_batch_matrix_list_mul_matrix(list_a, b, batch_size=1024, device="cpu"):
    """
    Compute ``list_a[0] @ ... @ list_a[-1] @ b`` with torch, in column
    batches of ``b``, and return the result as a float32 NumPy array.

    input:
    list_a: either a list/tuple of matrices or a single matrix (NumPy arrays)
    b: the last matrix (NumPy array); its columns are processed in batches
    batch_size: columns of ``b`` per batch; -1 processes all columns at once
    device: torch device on which the products are computed
    """
    if not isinstance(list_a, (list, tuple)):
        list_a = [list_a]
    n_cols = b.shape[1]
    if batch_size == -1:
        # max(..., 1) keeps range() valid when ``b`` has no columns.
        batch_size = max(n_cols, 1)
    with torch.no_grad():
        list_torch_a = [torch.from_numpy(a).float().to(device) for a in list_a]
        list_res = []
        for start in range(0, n_cols, batch_size):
            cur_product = torch.from_numpy(b[:, start:start + batch_size]).float().to(device)
            # Multiply right-to-left so every intermediate is batch-sized.
            for torch_a in reversed(list_torch_a):
                cur_product = torch_a @ cur_product
            list_res.append(cur_product.cpu().numpy())
    if not list_res:
        # ``b`` has no columns: return an empty result with the right height.
        n_rows = list_a[0].shape[0] if list_a else b.shape[0]
        return np.empty((n_rows, 0), dtype=np.float32)
    return np.concatenate(list_res, axis=1)

def np_wthresh(A: np.array, lam: float) -> np.array:
    """
    Soft-threshold a NumPy array elementwise.

    Entries with ``|a| <= lam`` become 0. Entries with ``|a| > lam`` are
    shrunk towards zero by ``lam``, i.e. ``sign(a) * (|a| - lam)``. Entries
    matching neither comparison (NaN) are left unchanged.

    The input is not modified. The result is a copy with ``A``'s dtype.
    """
    magnitude = np.abs(A)
    result = A.copy()

    result[magnitude <= lam] = 0

    shrink = magnitude > lam
    result[shrink] = np.sign(A[shrink]) * (magnitude[shrink] - lam)
    return result

def torch_wthresh(A: torch.tensor, lam: torch.tensor) -> torch.tensor:
    """Soft-threshold ``A`` elementwise: ``sign(A) * max(|A| - lam, 0)``."""
    # threshold(x, 0, 0) keeps x where x > 0 and writes 0 elsewhere; it is
    # the functional form of torch.nn.Threshold(0, 0).
    shrunk = torch.nn.functional.threshold(torch.abs(A) - lam, 0, 0)
    return torch.sign(A) * shrunk

def confidence_interval(data, mean=None, sem=None, data_len=None, confidence=0.95):
    """Student-t confidence interval for the mean of ``data``.

    Args:
        data: sample values.
        mean: precomputed sample mean; ``np.mean(data)`` when None.
        sem: precomputed standard error of the mean; ``scipy.stats.sem(data)``
            when None.
        data_len: sample size used for the degrees of freedom
            (``data_len - 1``); defaults to ``len(data)``.
        confidence: confidence level, e.g. 0.95.

    Returns:
        ``(m, h)``: the interval's midpoint and half-width, i.e. the interval
        is ``[m - h, m + h]``.
    """
    mean = np.mean(data) if mean is None else mean
    sem = st.sem(data) if sem is None else sem
    dof = (len(data) if data_len is None else data_len) - 1
    # Pass the confidence level positionally. SciPy renamed this keyword from
    # ``alpha`` to ``confidence`` (1.9) and later releases removed ``alpha``,
    # so the positional form works across versions.
    bounds = st.t.interval(confidence, dof, loc=mean, scale=sem)
    m = np.mean(bounds)
    h = (bounds[1] - bounds[0]) / 2
    return m, h

def latex_sample_mean_std_confidence(res, confidence=0.95):
    """Format summary statistics of ``res`` as two LaTeX-ready lines.

    Values are multiplied by 100 and printed with two decimals. The first
    line is "mean ± std" (population std, ``np.std``). The second is
    "mean ± half-width" of the Student-t confidence interval at
    ``confidence``.
    """
    mean = np.mean(res)
    std = np.std(res)
    _, interval = confidence_interval(res, confidence=confidence)
    # "\\pm" yields a literal backslash; a bare "\p" is an invalid escape
    # sequence (SyntaxWarning on Python 3.12+). Output is unchanged.
    res_str = "Mean and std: {:.2f}$\\pm${:.2f}\n".format(mean*100, std*100)
    res_str += f"Mean and {confidence:.2f} interval: {mean*100:.2f}$\\pm${interval*100:.2f}\n"
    return res_str

def table_ready_mean_confidence(res, confidence=0.95, scale=100):
    """Return "mean $\\pm$ CI half-width" of ``res`` for a LaTeX table.

    Both values are multiplied by ``scale`` and printed with two decimals.
    The half-width comes from the Student-t interval at ``confidence``.
    """
    mean = np.mean(res)
    _, interval = confidence_interval(res, confidence=confidence)
    # "\\pm" avoids the invalid "\p" escape; the output string is unchanged.
    return f"{mean*scale:.2f}$\\pm${interval*scale:.2f}"

def table_ready_mean_std(res, scale=100):
    """Return "mean $\\pm$ std" of ``res`` for a LaTeX table.

    Both values are multiplied by ``scale`` and printed with two decimals.
    ``std`` is the population standard deviation (``np.std``, ddof=0).
    """
    mean = np.mean(res)
    std = np.std(res)
    # "\\pm" avoids the invalid "\p" escape; the output string is unchanged.
    return f"{mean*scale:.2f}$\\pm${std*scale:.2f}"

class CPU_Unpickler(pickle.Unpickler):
    """
    Unpickler that loads pickled torch storages onto the CPU.

    When the pickle references ``torch.storage._load_from_bytes``, this
    returns a loader that calls ``torch.load(..., map_location='cpu')`` on the
    stored bytes. This avoids device errors for objects saved on a GPU
    machine. Every other global is resolved normally.
    Adapted from https://stackoverflow.com/questions/57081727/load-pickle-file-obtained-from-gpu-to-cpu

    Security: like ``pickle`` itself, this can run arbitrary code from the
    file; only use it on files from trusted sources.
    """
    def find_class(self, module, name):
        is_storage_loader = (module, name) == ('torch.storage', '_load_from_bytes')
        if not is_storage_loader:
            return super().find_class(module, name)

        def _load_on_cpu(payload):
            return torch.load(io.BytesIO(payload), map_location='cpu')

        return _load_on_cpu
